#!/usr/bin/env python3
import ithildin as ith
import numpy as np
import matplotlib.pyplot as plt
import sys
from typing import List
from scipy.interpolate import RectBivariateSpline
from pydiffmap.diffusion_map import DiffusionMap
from pydiffmap import visualization as diff_visualization
import matplotlib

# Courtemanche (1998) model, 2 spatial dimensions, 21 variables.
# Load the simulation output stored under the stem "myokit12/myokit_2" via
# ithildin; the script below reads its variables through ``data.vars``.
data = ith.SimData.from_stem("myokit12/myokit_2")


# Crop the data to avoid boundary effects and the initialisation transient.
# TODO: the start time (begintijd) should in principle be read from the simulation log.
def snij_randen_weg(variabele, begintijd, eindtijd):
    """Crop an array along its first axis and trim the borders of the others.

    Axis 0 is sliced as ``[begintijd:eindtijd]`` (start/end time indices).
    Every other axis of length ``n >= 20`` keeps only ``[n // 8 : 7 * n // 8]``,
    i.e. an eighth is dropped on each side; axes shorter than 20 (such as a
    singleton axis of a 2D simulation) are kept whole.

    Parameters
    ----------
    variabele : numpy.ndarray
        Array to crop.  Any number of dimensions is accepted.
    begintijd, eindtijd : int or None
        Slice bounds for axis 0; ``None`` leaves that side open, as in a
        normal slice.

    Returns
    -------
    numpy.ndarray
        ``variabele`` indexed with basic slices (a view for numpy arrays).
    """

    def begin(index):
        if index == 0:
            return begintijd
        # Same "short axis" guard as in ``eind``: short axes are not cropped.
        # Without it, axes of length 8..19 were cropped at the start only.
        if variabele.shape[index] < 20:
            return 0
        return variabele.shape[index] // 8

    def eind(index):
        if index == 0:
            return eindtijd
        if variabele.shape[index] < 20:
            return variabele.shape[index]
        return (7 * variabele.shape[index]) // 8

    # One slice per axis, so the crop works for any number of dimensions.
    return variabele[tuple(slice(begin(i), eind(i)) for i in range(variabele.ndim))]

def snij_randen_weg2(variables, begintijd, eindtijd):
    """Crop every array of a name -> array mapping with ``snij_randen_weg``.

    Returns a new dict with the same keys, in the same order, whose values
    are the cropped arrays; ``variables`` itself is left untouched.
    """
    return {
        naam: snij_randen_weg(variables[naam], begintijd, eindtijd)
        for naam in variables.keys()
    }

# Analysis window (indices along the first axis) and the number of points
# sampled for each scatter plot.
t0 = 1
t1 = 150
aantal_punten = 15000
# Crop every variable and flatten it to 1-D.  plot4 indexes all variables
# with the same flat indices, which lines samples up assuming every variable
# has the same shape.
# NOTE: this name shadows the builtin ``vars``; it is kept because plot4 and
# the rest of the script read this global.
vars = {
    naam: veld.ravel()
    for naam, veld in snij_randen_weg2(data.vars, t0, t1).items()
}

def plot4(x, y, z, w, autoclose=False, *, variables=None, n_points=None,
          save_path=None):
    """Scatter-plot three variables in 3-D, coloured by a fourth.

    ``n_points`` sample indices are drawn with replacement from the global
    NumPy RNG (so ``np.random.seed`` makes the selection reproducible), and
    the same indices are used for all four variables.

    Parameters
    ----------
    x, y, z : str
        Keys of the variables placed on the x-, y- and z-axis.
    w : str
        Key of the variable used for the point colour; it is shown as a
        labelled colour bar.
    autoclose : bool
        If False (default), block in ``plt.show()`` before closing the figure.
    variables : mapping of str to 1-D array, optional
        Data to plot; defaults to the module-level ``vars``.
    n_points : int, optional
        Number of sampled points; defaults to the module-level
        ``aantal_punten``.
    save_path : str or path-like, optional
        If given, the figure is also saved there (before it is shown).
    """
    if variables is None:
        variables = vars
    if n_points is None:
        n_points = aantal_punten

    choices = np.random.choice(variables[x].shape[0], n_points, replace=True)
    var1 = variables[x][choices]
    var2 = variables[y][choices]
    var3 = variables[z][choices]
    var4 = variables[w][choices]

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    points = ax.scatter(var1, var2, var3, c=var4, cmap='viridis')
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.set_zlabel(z)
    ax.axis('tight')
    # Without a colour bar the fourth variable cannot be read from the figure.
    fig.colorbar(points, ax=ax, label=w)

    if save_path is not None:
        fig.savefig(save_path)
    if not autoclose:
        plt.show()
    # Close this specific figure rather than whatever figure is "current".
    plt.close(fig)

matplotlib.use("TkAgg")
# Plot random combinations of four variables.
keys = list(vars.keys())
# 'K' is expected to be constant (the K+ concentration apparently stays
# fixed), so it carries no information for these plots and is dropped.
# Compare min and max instead of testing ``np.std(...) == 0``: for a constant
# float array the computed mean can be off by one rounding step, giving a
# tiny non-zero std.  Raise explicitly, because ``assert`` is skipped under
# ``python -O``.
k_values = data.vars['K']
if np.min(k_values) != np.max(k_values):
    raise ValueError("expected variable 'K' to be constant before excluding it from the plots")
keys.remove('K')
keys.sort()
np.random.seed(1234)
for _ in range(500):
    which_keys = np.random.choice(keys, 4, replace=False)
    print(which_keys)
    # Pass autoclose=True to skip the interactive plt.show() per figure.
    plot4(which_keys[0], which_keys[1], which_keys[2], which_keys[3])

plot4("gatexr", "gatev", "gatew", "Ca")
# Other candidate combinations noted while exploring:
# gateh, Na, gateu, ??
# gated, >Ca, gateexr

# Courtemanche (1998) paper: doi:10.1152/ajpheart.1998.275.1.H301
# (KU Leuven proxy: https://journals-physiology-org.kuleuven.e-bronnen.be/doi/epdf/10.1152/ajpheart.1998.275.1.H301)

# Combinations found by the random search:
# ['gateoi' 'gatew' 'gateoa' 'gatefCa']  # boring
# ['gateui' 'Ca' 'Carel' 'gated']  # boring
# ['gateh' 'gateu' 'Na' 'Caup']  # viewed with the axes permuted below
for combinatie in (
    ('gateh', 'gateu', 'Na', 'Caup'),
    ('gateh', 'gateu', 'Caup', 'Na'),
    ('gateh', 'Caup', 'gateu', 'Na'),
    ('Caup', 'gateh', 'gateu', 'Na'),
):
    plot4(*combinatie)

# Try: gateu, Caup, ???, ?? -- fix the first two axes and pick the third
# axis and the colour variable at random.
np.random.seed(1234)
# The candidate list is the same on every iteration, so build it once.
keys2 = list(keys)
keys2.remove('gateu')
keys2.remove('Caup')
for _ in range(40):
    which_keys2 = np.random.choice(keys2, 2, replace=False)
    print(['gateu', 'Caup', which_keys2[0], which_keys2[1]])
    plot4('gateu', 'Caup', which_keys2[0], which_keys2[1])

# Observations:
# ['gateu', 'Caup', 'gatexs', 'gateoa'] --> complex, still 2D? (maybe check with diffmap)
# ['gateu', 'Caup', 'gateoi', 'gateh'] --> closed, non-weird
# ['gateu', 'Caup', 'gateui', 'gatev'] --> not closed, complex, still 2D, maybe
# ['gateu', 'Caup', 'gateui', 'gateh'] --> not closed, complex, still 2D, maybe
# ['gateu', 'Caup', 'gateui', 'Ca'] --> complex, maybe closed
# ['gateu', 'Caup', 'gateoa', 'gateoi'] --> maybe 2D, but check
# ['gateu', 'Caup', 'gatev', 'gatefCa'] --> maybe 3D, check!
# ['gateu', 'Caup', 'gatew', 'gatef'] --> check
